%This script runs the simulations used for Figure 5
%Corresponding author: Rodrigo Adao
%Date: 09/11/2024
%Input: model_data_fgkk.mat
%Output: Figure 5 (panels 5a and 5b, saved as .png and .eps) and simulation results output_Fig_5.mat

%% Preliminaries

clear all;
close all;

%Trade-share target. It selects which precomputed share file is loaded below
%(simulation_shares_t<trade_target>.mat). NOTE(review): presumably the cutoff
%that defines the sample of large varieties -- confirm in simulation_preliminary_steps.
trade_target = 0.85; %baseline has 0.85

%Simulation parameters
Nsimul = 1500; %Number of Monte Carlo draws of the shocks (baseline 1500)

%% Import Data
%set local paths
%Build the helper folder with the platform's file separator so the script also
%runs on macOS/Linux. On Windows this is exactly "..\Functions\" (the trailing
%separator is kept in case set_local_paths concatenates onto function_path).
function_path = ".." + filesep + "Functions" + filesep;
addpath(function_path,'-frozen');
set_local_paths;

%Import data and compute researcher predictions.
%sigma_alt is set before simulation_preliminary_steps, which appears to be a
%script sharing this workspace (variables such as Nobs and ww_dW used below
%are not defined in this file). NOTE(review): confirm how sigma_alt is used there.
sigma_alt = 4;
simulation_preliminary_steps;

%% Create IV matrices 

%Naive IV matrices.
%var_adj: sum of squared welfare weights ww_dW divided by Nobs.
%The weighted naive shares are residualized with Cn*(MCn_term1*...) and then
%rescaled by var_adj. NOTE(review): Cn/MCn_term1 come from
%simulation_preliminary_steps; presumably a projection on the naive controls.
var_adj = sum(ww_dW.^2)/Nobs;
share_wNC_nGE = ww_dWmat_naive*shareIV_tau;
share_wMC_nGE = ( share_wNC_nGE - Cn*(MCn_term1*share_wNC_nGE) )/var_adj;
clearvars share_wNC_nGE share_NC_nGE shareIV_tau

%Auxiliary matrices for control and welfare adjustments
%load auxiliary adjustment vectors
load(save_share_path + "simulation_shares_t" + trade_target + ".mat", 'adjGE_wMC')

%Scale the researcher's predicted shares row by row by adjGE_wMC, then
%residualize with C*(MC_term1*...) (same pattern as above, with the full controls).
share_wMC_dW = spdiags(adjGE_wMC, 0, Nobs, Nobs)*shareMOD_dW;
share_wMC_dW = share_wMC_dW - C*(MC_term1*share_wMC_dW);
%Drop temporaries (names that do not exist are ignored by clearvars)
clearvars share_MC_dWa share_wMC_dWa adjGE_MC_mat adjGE_wMC_mat ww_dWmat ww_dWmat_naive

%% Compute equilibrium given different draws of parameters

Nv = 21; %number of IV weights on the grid linspace(0,1,Nv) mixing znaive and zmain

%Misspecification parameters: the true DGP uses sigma = theta./gamma and
%gammaQ = gamma.
gamma_grid_plot = 10;
Ngamma = length(gamma_grid_plot);

%Researcher's elasticity (the true DGP elasticity is theta./gamma_grid_plot)
theta = 2.53;

%Shock parameters
avg_sh = .02; %average tariff shock
std_sh = .06; %st dev of tariff shock
std_ep = .06; %st dev of other shocks

%Every *_grid below has one row per entry of gamma_grid_plot.

%Parameters for simulation draws
sd_shifters_grid   = std_sh*ones(Ngamma,1);
mean_shifters_grid = avg_sh*ones(Ngamma,1);

%parameters for draws of shocks to other parameters from normal distribution
mean_deta_grid     = zeros(Ngamma,1);
sd_da_grid         = std_ep*ones(Ngamma,1);
sd_dzstar_grid     = std_ep*ones(Ngamma,1);
sd_dastar_grid     = std_ep*ones(Ngamma,1);

%Parameters controlling true DGP
sigma_grid      = theta./gamma_grid_plot;
omega_star_grid = omega_star*ones(Ngamma,1);
eta_grid        = eta       *ones(Ngamma,1);
kappa_grid      = kappa     *ones(Ngamma,1);
sigma_star_grid = sigma_star*ones(Ngamma,1);
rho_exp_grid    = ones(Ngamma,1);
rho_imp_grid    = ones(Ngamma,1);

%Observation-level misspecification draws: zero mean and zero sd, so they are
%switched off in this configuration.
gamma_mean_px_grid = zeros(Ngamma,1);
gamma_mean_pm_grid = zeros(Ngamma,1);
gamma_mean_rm_grid = zeros(Ngamma,1);

gamma_sd_px_grid = zeros(Ngamma,1);
gamma_sd_pm_grid = zeros(Ngamma,1);
gamma_sd_rm_grid = zeros(Ngamma,1);

gammaM_grid = ones(Ngamma,1);
gammaX_grid = ones(Ngamma,1);
gammaQ_grid = gamma_grid_plot;

%Simulate one configuration per grid entry, so adding values to
%gamma_grid_plot is picked up automatically (equals 1 in the baseline).
Npar = Ngamma;

Npar
%Output arrays, one page per configuration j and one column per draw n:
ests_theta = zeros(3, Nsimul, Npar);     %[hat_theta; SE_theta; hath_theta]
ests       = zeros(5, Nsimul, Npar);     %[dWstar; dW; dW_hat; Delta_W; Delta_W_hat]
est_zc     = zeros(Nv*4, Nsimul, Npar);  %IV-grid test without estimation (Nv x 4, stacked)
est_hat_zc = zeros(Nv*6, Nsimul, Npar);  %IV-grid joint estimation/test (Nv x 6, stacked)

for j=1:Npar
    tic
    display('------------- running new j -----------------')
    j
    
    %shock process: tariff shifters are drawn from normrnd(mean_shifters, sd_shifters);
    %adj rescales the preferred IV below by Nobs*(mean/sd) of the shifters
    sd_shifters  = sd_shifters_grid(j);
    mean_shifters = mean_shifters_grid(j);
    adj = Nobs*(mean_shifters/sd_shifters);

    mean_deta = mean_deta_grid(j);
    sd_da     = sd_da_grid(j);
    sd_dzstar = sd_dzstar_grid(j);
    sd_dastar = sd_dastar_grid(j);

    %Generate true DGP   
    %Set elasticities in the model
    omega_starDGP = omega_star_grid(j);    
    sigmaDGP = sigma_grid(j);           
    etaDGP = eta_grid(j);               
    kappaDGP = kappa_grid(j);           
    sigma_starDGP = sigma_star_grid(j); 
    rhoMDGP = rho_imp_grid(j) ;
    rhoXDGP = rho_exp_grid(j) ;
    gammaMDGP = gammaM_grid(j) ;
    gammaXDGP = gammaX_grid(j) ;
    gammaQDGP = gammaQ_grid(j) ;

    %Observation-level misspecification: diagonal matrix 1+gamma_n that scales
    %the true shares (identity when the gamma mean/sd grids are zero)
    gamma_n = (indM==1).*( normrnd(gamma_mean_pm_grid(j), gamma_sd_pm_grid(j), Nobs, 1) )... 
            + (indT==1).*( normrnd(gamma_mean_rm_grid(j), gamma_sd_rm_grid(j), Nobs, 1) )...
            + (indX==1).*( normrnd(gamma_mean_px_grid(j), gamma_sd_px_grid(j), Nobs, 1) );
    gammaDGP = spdiags(1+ gamma_n, 0, Nobs,Nobs);

    %Invert fundamentals under the true elasticities, then solve the true
    %equilibrium responses (rhoDGP_*) to each shock
    [delta_ig_DGP0, a_star_ig_DGP0, z_star_ig_DGP0, a_ig_DGP0, aM_g_DGP0, AM_s_DGP0, AD_s_DGP0, betaT_DGP0, beta_s_DGP0, alphaL_s_DGP0, alphaI_s_DGP0, alpha_ks_DGP0, barL_rs_DGP0, Z_rs_DGP0, D_DGP0, E_s_DGP0, F_DGP0] = ... 
        invert_param_full(x_ig_0, m_ig_0, tau_star_ig_0, tau_ig_0, trade_con_share, tot_labor_comp, tot_intermediate, sales_s, IO_sales, Lsr, ...
        Dsg, DX, DM, omega_starDGP, sigmaDGP, etaDGP, kappaDGP, sigma_starDGP, gammaMDGP, gammaXDGP,  tol_parm);

    [p_s_IV, PM_s_IV, ...
         rhoDGP_qm_ig_tau_ig, rhoDGP_qm_ig_taustar_ig, rhoDGP_qm_ig_a_ig, rhoDGP_qm_ig_zstar_ig, rhoDGP_qm_ig_astar_ig, rhoDGP_qm_ig_delta_ig, ...
         rhoDGP_pm_ig_tau_ig, rhoDGP_pm_ig_taustar_ig, rhoDGP_pm_ig_a_ig, rhoDGP_pm_ig_zstar_ig, rhoDGP_pm_ig_astar_ig, rhoDGP_pm_ig_delta_ig, ...
         rhoDGP_px_ig_tau_ig, rhoDGP_px_ig_taustar_ig, rhoDGP_px_ig_a_ig, rhoDGP_px_ig_zstar_ig, rhoDGP_px_ig_astar_ig, rhoDGP_px_ig_delta_ig, ...
         rhoDGP_pstar_ig_tau_ig, rhoDGP_pstar_ig_taustar_ig, rhoDGP_pstar_ig_a_ig, rhoDGP_pstar_ig_zstar_ig, rhoDGP_pstar_ig_astar_ig, rhoDGP_pstar_ig_delta_ig, ...
         rhoDGP_rm_ig_tau_ig, rhoDGP_rm_ig_taustar_ig, rhoDGP_rm_ig_a_ig, rhoDGP_rm_ig_zstar_ig, rhoDGP_rm_ig_astar_ig, rhoDGP_rm_ig_delta_ig] ...
         = compute_equilibrium_foa_full( delta_ig_DGP0, a_star_ig_DGP0, z_star_ig_DGP0, a_ig_DGP0, aM_g_DGP0, AM_s_DGP0, AD_s_DGP0, betaT_DGP0, beta_s_DGP0, alphaL_s_DGP0, alphaI_s_DGP0, alpha_ks_DGP0, barL_rs_DGP0, Z_rs_DGP0, D_DGP0, ... 
             tau_ig_0, tau_star_ig_0, m_ig_0, E_s_DGP0, p_s_0, Dsg, DX, DM, omega_starDGP, sigmaDGP, etaDGP, kappaDGP, sigma_starDGP, nu, epsilon, gammaQDGP, gammaXDGP, gammaMDGP, rhoXDGP, rhoMDGP, tol_conv, adj_inner_g0, adj_outter_g0, ... 
             ind_Mshifts_vec, ind_Xshifts_vec, ind_Msample_ig, ind_Xsample_ig ); 

    %True responses to the [tau, taustar] shifters
    shareDGP_pm      = [rhoDGP_pm_ig_tau_ig    , rhoDGP_pm_ig_taustar_ig   ];
    shareDGP_rm      = [rhoDGP_rm_ig_tau_ig    , rhoDGP_rm_ig_taustar_ig   ];
    shareDGP_px      = [rhoDGP_px_ig_tau_ig    , rhoDGP_px_ig_taustar_ig   ];
    shareDGP_pstar   = [rhoDGP_pstar_ig_tau_ig , rhoDGP_pstar_ig_taustar_ig];
    shareDGP_qm      = [rhoDGP_qm_ig_tau_ig    , rhoDGP_qm_ig_taustar_ig   ];

    %Stacked [pm; rm; px] outcomes, scaled by the misspecification matrix
    shareDGP_dW = gammaDGP*[shareDGP_pm; shareDGP_rm; shareDGP_px];

    %Stacked [pm; rm; px] responses to the non-tariff shocks
    shareDGP_dW_da     = [rhoDGP_pm_ig_a_ig    ; rhoDGP_rm_ig_a_ig    ; rhoDGP_px_ig_a_ig    ];
    shareDGP_dW_dzstar = [rhoDGP_pm_ig_zstar_ig; rhoDGP_rm_ig_zstar_ig; rhoDGP_px_ig_zstar_ig];
    shareDGP_dW_dastar = [rhoDGP_pm_ig_astar_ig; rhoDGP_rm_ig_astar_ig; rhoDGP_px_ig_astar_ig];
    
    %Free memory. rhoDGP_qm_ig_* and rhoDGP_pm_ig_* shock responses are kept
    %because the parfor loop below uses them.
    clear    rhoDGP_px_ig_tau_ig rhoDGP_px_ig_taustar_ig rhoDGP_px_ig_a_ig rhoDGP_px_ig_zstar_ig rhoDGP_px_ig_astar_ig rhoDGP_px_ig_delta_ig ...
             rhoDGP_pstar_ig_tau_ig rhoDGP_pstar_ig_taustar_ig rhoDGP_pstar_ig_a_ig rhoDGP_pstar_ig_zstar_ig rhoDGP_pstar_ig_astar_ig rhoDGP_pstar_ig_delta_ig ...
             rhoDGP_rm_ig_tau_ig rhoDGP_rm_ig_taustar_ig rhoDGP_rm_ig_a_ig rhoDGP_rm_ig_zstar_ig rhoDGP_rm_ig_astar_ig rhoDGP_rm_ig_delta_ig ...
             shareDGP_rm_dtar shareDGP_px_dtar shareDGP_pstar_dtar shareDGP_px shareDGP_rm ...
             rhoDGP_qm_ig_delta_ig rhoDGP_qm_ig_tau_ig  rhoDGP_qm_ig_taustar_ig ...
             delta_ig_DGP0 a_star_ig_DGP0 z_star_ig_DGP0 a_ig_DGP0 aM_g_DGP0 barL_rs_DGP0 Z_rs_DGP0 
            
    toc
    
    %% Compute jacobian of predictions around researcher's model
    %One-sided finite-difference derivative of the researcher's stacked
    %[pm; rm; px] shares with respect to the elasticity:
    %shareMOD_dW_sigma0 is solved at theta, shareMOD_dW_sigmah at theta+fd_step.
    %None of the inputs below depend on j.
    %NOTE(review): the theta solve uses fundamentals inverted at theta+fd_step,
    %while the theta+fd_step solve uses fundamentals inverted at sigma+fd_step.
    %The two coincide only if sigma == theta -- confirm the intended inversion point.
    tic
    fd_step = 0.01;   %finite-difference step (not named eps, which would shadow the builtin)
 
    [delta_ig_0, a_star_ig_0, z_star_ig_0, a_ig_0, aM_g_0, AM_s_0, AD_s_0, betaT_0, beta_s_0, alphaL_s_0, alphaI_s_0, alpha_ks_0, barL_rs_0, Z_rs_0, D_0, E_s_0, F_0] = ... 
        invert_param_full(x_ig_0, m_ig_0, tau_star_ig_0, tau_ig_0, trade_con_share, tot_labor_comp, tot_intermediate, sales_s, IO_sales, Lsr, ...
        Dsg, DX, DM, omega_star, theta+fd_step, eta, kappa, sigma_star, gammaM, gammaX,  tol_parm);
        
    [p_s_IV, PM_s_IV, ...
             rho_qm_ig_tau_ig, rho_qm_ig_taustar_ig, rho_qm_ig_a_ig, rho_qm_ig_zstar_ig, rho_qm_ig_astar_ig, rho_qm_ig_delta_ig, ...
             rho_pm_ig_tau_ig, rho_pm_ig_taustar_ig, rho_pm_ig_a_ig, rho_pm_ig_zstar_ig, rho_pm_ig_astar_ig, rho_pm_ig_delta_ig, ...
             rho_px_ig_tau_ig, rho_px_ig_taustar_ig, rho_px_ig_a_ig, rho_px_ig_zstar_ig, rho_px_ig_astar_ig, rho_px_ig_delta_ig, ...
             rho_pstar_ig_tau_ig, rho_pstar_ig_taustar_ig, rho_pstar_ig_a_ig, rho_pstar_ig_zstar_ig, rho_pstar_ig_astar_ig, rho_pstar_ig_delta_ig, ...
             rho_rm_ig_tau_ig, rho_rm_ig_taustar_ig, rho_rm_ig_a_ig, rho_rm_ig_zstar_ig, rho_rm_ig_astar_ig, rho_rm_ig_delta_ig] ...
             = compute_equilibrium_foa_full( delta_ig_0, a_star_ig_0, z_star_ig_0, a_ig_0, aM_g_0, AM_s_0, AD_s_0, betaT_0, beta_s_0, alphaL_s_0, alphaI_s_0, alpha_ks_0, barL_rs_0, Z_rs_0, D_0, ... 
             tau_ig_0, tau_star_ig_0, m_ig_0, E_s_0, p_s_0, Dsg, DX, DM, omega_star, theta, eta, kappa, sigma_star, nu, epsilon, gammaQ, gammaX, gammaM, rhoX, rhoM, tol_conv, adj_inner_g0, adj_outter_g0, ... 
             ind_Mshifts_vec, ind_Xshifts_vec, ind_Msample_ig, ind_Xsample_ig );  

    share_pm    = [rho_pm_ig_tau_ig    , rho_pm_ig_taustar_ig   ];
    share_px    = [rho_px_ig_tau_ig    , rho_px_ig_taustar_ig   ];
    share_rm    = [rho_rm_ig_tau_ig    , rho_rm_ig_taustar_ig   ];
    shareMOD_dW_sigma0    = sparse( [share_pm; share_rm; share_px] );
 
    [delta_ig_0, a_star_ig_0, z_star_ig_0, a_ig_0, aM_g_0, AM_s_0, AD_s_0, betaT_0, beta_s_0, alphaL_s_0, alphaI_s_0, alpha_ks_0, barL_rs_0, Z_rs_0, D_0, E_s_0, F_0] = ... 
        invert_param_full(x_ig_0, m_ig_0, tau_star_ig_0, tau_ig_0, trade_con_share, tot_labor_comp, tot_intermediate, sales_s, IO_sales, Lsr, ...
        Dsg, DX, DM, omega_star, sigma+fd_step, eta, kappa, sigma_star, gammaM, gammaX,  tol_parm);
        
    [p_s_IV, PM_s_IV, ...
             rho_qm_ig_tau_ig, rho_qm_ig_taustar_ig, rho_qm_ig_a_ig, rho_qm_ig_zstar_ig, rho_qm_ig_astar_ig, rho_qm_ig_delta_ig, ...
             rho_pm_ig_tau_ig, rho_pm_ig_taustar_ig, rho_pm_ig_a_ig, rho_pm_ig_zstar_ig, rho_pm_ig_astar_ig, rho_pm_ig_delta_ig, ...
             rho_px_ig_tau_ig, rho_px_ig_taustar_ig, rho_px_ig_a_ig, rho_px_ig_zstar_ig, rho_px_ig_astar_ig, rho_px_ig_delta_ig, ...
             rho_pstar_ig_tau_ig, rho_pstar_ig_taustar_ig, rho_pstar_ig_a_ig, rho_pstar_ig_zstar_ig, rho_pstar_ig_astar_ig, rho_pstar_ig_delta_ig, ...
             rho_rm_ig_tau_ig, rho_rm_ig_taustar_ig, rho_rm_ig_a_ig, rho_rm_ig_zstar_ig, rho_rm_ig_astar_ig, rho_rm_ig_delta_ig] ...
             = compute_equilibrium_foa_full( delta_ig_0, a_star_ig_0, z_star_ig_0, a_ig_0, aM_g_0, AM_s_0, AD_s_0, betaT_0, beta_s_0, alphaL_s_0, alphaI_s_0, alpha_ks_0, barL_rs_0, Z_rs_0, D_0, ... 
             tau_ig_0, tau_star_ig_0, m_ig_0, E_s_0, p_s_0, Dsg, DX, DM, omega_star, theta +fd_step, eta, kappa, sigma_star, nu, epsilon, gammaQ, gammaX, gammaM, rhoX, rhoM, tol_conv, adj_inner_g0, adj_outter_g0, ... 
             ind_Mshifts_vec, ind_Xshifts_vec, ind_Msample_ig, ind_Xsample_ig );  

    share_pm    = [rho_pm_ig_tau_ig    , rho_pm_ig_taustar_ig   ];
    share_px    = [rho_px_ig_tau_ig    , rho_px_ig_taustar_ig   ];
    share_rm    = [rho_rm_ig_tau_ig    , rho_rm_ig_taustar_ig   ];
    shareMOD_dW_sigmah    = sparse( [share_pm; share_rm; share_px] );

    shareMOD_dW_dsigma  = (shareMOD_dW_sigmah  - shareMOD_dW_sigma0 )/fd_step;
    
    clear    rho_qm_ig_tau_ig rho_qm_ig_taustar_ig rho_qm_ig_a_ig rho_qm_ig_zstar_ig rho_qm_ig_astar_ig rho_qm_ig_delta_ig ...
         rho_pm_ig_tau_ig rho_pm_ig_taustar_ig rho_pm_ig_a_ig rho_pm_ig_zstar_ig rho_pm_ig_astar_ig rho_pm_ig_delta_ig ...
         rho_px_ig_tau_ig rho_px_ig_taustar_ig rho_px_ig_a_ig rho_px_ig_zstar_ig rho_px_ig_astar_ig rho_px_ig_delta_ig ...
         rho_pstar_ig_tau_ig rho_pstar_ig_taustar_ig rho_pstar_ig_a_ig rho_pstar_ig_zstar_ig rho_pstar_ig_astar_ig rho_pstar_ig_delta_ig ...
         rho_rm_ig_tau_ig rho_rm_ig_taustar_ig rho_rm_ig_a_ig rho_rm_ig_zstar_ig rho_rm_ig_astar_ig rho_rm_ig_delta_ig ...
         share_qm share_pm share_pstar share_px share_rm shareMOD_dW_sigmah  ...
         z_star_ig_0 delta_ig_0 a_ig_0 a_star_ig_0

    toc

    tic
    %Grid of weights on the naive IV (0 = preferred IV only, 1 = naive IV only).
    %Loop-invariant, so it is built once and broadcast to the workers.
    naive_wt = linspace(0,1,Nv);

    %Simulation across draws of shocks
    parfor n=1:Nsimul    
    %for n=1:Nsimul
    %n
    %% simulated data
        %simulated shifts (import tariffs first, then export-side shifters)
        dtau_sim = normrnd(mean_shifters, sd_shifters, NMs, 1);
        dtaustar_sim = normrnd(mean_shifters, sd_shifters, NXs, 1);  
        shifters = [dtau_sim; dtaustar_sim];
        shiftersIV = (shifters - mean(shifters))/std(shifters);
        
        %simulated shocks
        da_sim     = normrnd(mean_deta, sd_da    , NMs, 1);
        dzstar_sim = normrnd(mean_deta, sd_dzstar, NMs, 1);
        dastar_sim = normrnd(mean_deta, sd_dastar, NXs, 1);

        detaDGP = shareDGP_dW_da*da_sim + shareDGP_dW_dzstar*dzstar_sim + shareDGP_dW_dastar*dastar_sim ;

        %Simulated predictions (researcher model) and outcomes (true DGP)
        dx     = shareMOD_dW*shifters;
        dxstar = shareDGP_dW*shifters ;
        dy = dxstar + detaDGP;
        dyt_NC = dy - dx;   %prediction error with the researcher's baseline parameters

        % Welfare statistics
        dW = ww_dW'*dx;    
        dWstar = ww_dW'*dxstar;
        Delta_W = dWstar - dW;

        %simulated data for estimation
        deta_dlqm = rhoDGP_qm_ig_a_ig*da_sim + rhoDGP_qm_ig_zstar_ig*dzstar_sim + rhoDGP_qm_ig_astar_ig*dastar_sim ; 
        deta_dlpm = rhoDGP_pm_ig_a_ig*da_sim + rhoDGP_pm_ig_zstar_ig*dzstar_sim + rhoDGP_pm_ig_astar_ig*dastar_sim ; 

        dlqm_ig     = shareDGP_qm*shifters      + deta_dlqm;
        dlpm_ig     = shareDGP_pm*shifters      + deta_dlpm;
       
        %Import-side instrument built from the standardized shifters
        shifters_IVimp = shiftersIV(exp_sind==0);
        IVimp_est = shareM_MCg_nGE*shifters_IVimp ;  
   
        %Residualize on D_Msample_g_ig: x - D'*(diag(1./rowsum(D))*D*x).
        %NOTE(review): this is within-group demeaning if D is a 0/1 group-by-obs indicator.
        dlqm_ig_resg    = dlqm_ig     - D_Msample_g_ig'*( diag(1./sum(D_Msample_g_ig, 2))*D_Msample_g_ig*dlqm_ig    );
        dlpm_ig_resg    = dlpm_ig     - D_Msample_g_ig'*( diag(1./sum(D_Msample_g_ig, 2))*D_Msample_g_ig*dlpm_ig    );
         
    %% Parameter estimation 

        %IV estimate -(z'q)/(z'p), its moment derivative and a shift-share SE
        hat_sigma = - (dlqm_ig_resg'*IVimp_est   )/(dlpm_ig_resg'*IVimp_est); 
        hat_nu_sigma = dlqm_ig_resg + hat_sigma*dlpm_ig_resg;
        Dsigma_nu = dlpm_ig_resg;
        Dsigma_h  = (Dsigma_nu'*IVimp_est)/NM  ;
        Rh = hat_nu_sigma'*shareM_MCg_nGE / NM; 
        %Rh = hat_nu_sigma'*shareM_wMCg_nGE / NM; 
        hat_Vhh = (Rh.^2)*(shifters_IVimp.^2);
        SE_sigma = ( hat_Vhh/(Dsigma_h^2) ).^(1/2);   

        %Estimation moment evaluated at the researcher's theta
        hath_theta = IVimp_est'*(dlqm_ig_resg + theta*dlpm_ig_resg)/NM;
            
        %Estimation output
        Dtheta_h = Dsigma_h;
        hat_theta = hat_sigma ;
        SE_theta = SE_sigma;
        ests_theta_n = [ hat_theta, SE_theta, hath_theta];
        
    %% Researcher predictions with estimated parameters
        
        %First-order update of the researcher's shares to the estimated elasticity
        shareMOD_hat_dW = shareMOD_dW_sigma0 + shareMOD_dW_dsigma*(hat_sigma - theta);

        %Prediction error with estimated parameters (input to the joint test)
        dx_hat     = shareMOD_hat_dW*shifters;
        dyt_hat_NC = dy - dx_hat;

        %Compute jacobian of predictions for inference
        Dsigma_dx = shareMOD_dW_dsigma*shifters;   
        Dtheta_dx = - Dsigma_dx;       
        
        %Welfare stats
        dW_hat = ww_dW'*dx_hat;
        Delta_W_hat = dWstar - dW_hat;
        
        est_n = [dWstar, dW, dW_hat, Delta_W, Delta_W_hat];

    %% Implement test with different IVs
        zn_MCg  = share_MCg_nGE*shiftersIV;
        zw_wMC = adj*share_wMC_dW*shiftersIV;
        %Rescale the naive IV to the spread of the preferred IV (invariant in v)
        sd_ratio = std(zw_wMC)/std(zn_MCg);

        est_zc_n = zeros(Nv,4); 
        est_hat_zc_n = zeros(Nv,6); 
        for v = 1:Nv
            %IV: convex combination of the rescaled naive IV and the preferred IV
            zc_MC_IV = naive_wt(v)*zn_MCg*sd_ratio + (1-naive_wt(v))*zw_wMC;
            share_zc_MC_IV = naive_wt(v)*share_MCg_nGE*sd_ratio  + (1-naive_wt(v))* adj*share_wMC_dW;
            
            %Test based on prior
            [beta_zc_MC_IV  , SE_zc_MC_IV, rej_zc_MC_IV, rej0_zc_MC_IV] = implement_test(dyt_NC, zc_MC_IV, share_zc_MC_IV, shiftersIV, critical, 0);
            est_zc_n(v,:) = [beta_zc_MC_IV, SE_zc_MC_IV, rej_zc_MC_IV, rej0_zc_MC_IV];
        
            %Test based on estimated model
            [beta_zc_MC_IV  , SE_zc_MC_IV, rej_zc_MC_IV, rej0_zc_MC_IV, info_zc_MC_IV, info0_zc_MC_IV] = implement_jointtest(dyt_hat_NC , zc_MC_IV, share_zc_MC_IV, shiftersIV, shifters_IVimp, hat_Vhh, Rh, Dtheta_h, Dtheta_dx, exp_sind, critical, 0);       
            est_hat_zc_n(v,:) = [beta_zc_MC_IV, SE_zc_MC_IV, rej_zc_MC_IV, rej0_zc_MC_IV, info_zc_MC_IV, info0_zc_MC_IV];
        end        
            
    %% save outcomes for step n
        ests_theta(:,n,j) = ests_theta_n';
        ests(:,n,j) = est_n';
        est_zc(:,n,j) = reshape(est_zc_n, Nv*4,1);
        est_hat_zc(:,n,j) = reshape(est_hat_zc_n, Nv*6,1);
        
    end
    toc
    %Checkpoint results after every configuration j
    save(save_simulation_path + "output_Fig_5.mat", 'ests_theta', 'ests', 'est_zc', 'est_hat_zc', 'Npar', 'Nv', 'gamma_grid_plot', 'trade_target');
end


%% Report results
load(save_simulation_path + "output_Fig_5.mat")

%Preallocate the summary arrays from the loaded dimensions. This also means
%re-running this section on its own (e.g. after a run with a different Npar
%or Nv) cannot mix in stale slices left in the workspace.
ests_stat        = zeros(5, 2, Npar);   %rows [dWstar; dW; dW_hat; Delta_W; Delta_W_hat], cols [mean, std]
ests_zc_stat     = zeros(Nv, 5, Npar);  %cols [mean beta, std beta, mean SE, rej0, rej]
ests_hat_zc_stat = zeros(Nv, 7, Npar);  %cols [mean beta, std beta, mean SE, rej0, rej, info0, info]
infoR2           = zeros(Nv, Npar);     %squared corr(beta_hat, hath_theta) per IV weight

for j = 1:Npar
    ests_est_j    = ests_theta(:,:,j);
    ests_j        = ests(:,:,j);
    ests_zc_j     = est_zc(:,:,j);
    ests_hat_zc_j = est_hat_zc(:,:,j);
    
    %Estimation (rows of ests_theta: hat_theta, SE_theta, hath_theta)
    hat_theta_j  = ests_est_j(1,:);
    SE_theta_j   = ests_est_j(2,:);
    hath_theta_j = ests_est_j(3,:)'; 

    %Welfare statistics (all five rows of ests)
    estimates_j = ests_j(1:5,:);

    %Joint estimation/test: est_hat_zc stacks six Nv-row blocks
    %[beta; SE; rej; rej0; info; info0], one column per simulation
    estimates_hat_zc_j = ests_hat_zc_j(0*Nv+1:1*Nv,:);
    SE_hat_zc    = mean(ests_hat_zc_j(1*Nv+1:2*Nv,:),2);
    rej_hat_zc   = mean(ests_hat_zc_j(2*Nv+1:3*Nv,:),2);
    rej0_hat_zc  = mean(ests_hat_zc_j(3*Nv+1:4*Nv,:),2);
    info_hat_zc  = mean(ests_hat_zc_j(4*Nv+1:5*Nv,:),2);
    info0_hat_zc = mean(ests_hat_zc_j(5*Nv+1:6*Nv,:),2);
    
    %Test without estimation: est_zc stacks [beta; SE; rej; rej0]
    estimates_zc_j = ests_zc_j(0*Nv+1:1*Nv,:);
    SE_zc    = mean(ests_zc_j(1*Nv+1:2*Nv,:),2);
    rej_zc   = mean(ests_zc_j(2*Nv+1:3*Nv,:),2);
    rej0_zc  = mean(ests_zc_j(3*Nv+1:4*Nv,:),2);

    ests_stat(:,:,j)        = [mean(estimates_j,2), std(estimates_j,1,2)];
    ests_zc_stat(:,:,j)     = [mean(estimates_zc_j,2), std(estimates_zc_j,1,2), SE_zc, rej0_zc, rej_zc];
    ests_hat_zc_stat(:,:,j) = [mean(estimates_hat_zc_j,2), std(estimates_hat_zc_j,1,2), SE_hat_zc, rej0_hat_zc, rej_hat_zc,  info0_hat_zc,  info_hat_zc];
    
    %Compute info as R2 of regressing beta_hat on hath
    for v=1:Nv
        hat_beta_v = estimates_hat_zc_j(v,:)';
        c = corr([hat_beta_v, hath_theta_j]);
        infoR2(v,j) = c(2,1)^2;
    end
    
end

%Console check: R2-based informativeness vs. the info statistics of the joint test
j=1;
corr([infoR2(:,j), ests_hat_zc_stat(:,6:7,j)])
[infoR2(:,j), ests_hat_zc_stat(:,6:7,j), ests_hat_zc_stat(:,6,j)./infoR2(:,j)]

%% Figure 5b
% Rejection rate of the joint test and informativeness of the estimation
% moment, plotted against the weight on the naive IV
% (0 = preferred IV only, 1 = naive IV only).
title_size = 14;
label_size = 20;
legend_size = 16;
marker_size = 10;
axes_size = 18;

figure('DefaultAxesFontSize', axes_size, 'Position', [10 10 600 600]);
plot_info_joint = tiledlayout(1,1);

ymin = 0 ;
ymax = 1;

xmin = 0;
xmax = 1;

%Report the last simulated configuration
j0 = Npar;

%Console diagnostics: mean dW_hat, mean Delta_W, mean Delta_W_hat, and the
%mean IV coefficient on the first weight-grid point (weight 0 on the naive IV)
What           = ests_stat(3,1,j0);
DeltaW         = ests_stat(4,1,j0);
DeltaWhat      = ests_stat(5,1,j0);
beta_IV_series = ests_zc_stat(1,1,j0);
[What, DeltaW, DeltaWhat, beta_IV_series ]

%Series over the weight grid (column vectors of length Nv)
x_axis          = linspace(0,1,Nv);
rej0_series     = ests_zc_stat(:,4,j0);
rej0_hat_series = ests_hat_zc_stat(:,4,j0);
info0_series    = ests_hat_zc_stat(:,6,j0);

corr([rej0_hat_series,info0_series])

plot(x_axis, rej0_hat_series, 'Color', 'black', 'Marker', 'o', 'MarkerSize', marker_size, 'MarkerFaceColor', 'black')
hold('on')
plot(x_axis, info0_series, 'Color', [0.5 0.5 0.5], 'Marker', 's', 'MarkerSize', marker_size, 'MarkerFaceColor', [0.5 0.5 0.5])
ylim([ymin ymax])
yticks([0 .2 .4 .6 .8 1])
legend('Rejection rate of H0 at 5% with estimation', 'Informativeness of estimation moment', 'Location', 'northwest', 'FontSize', legend_size)
xlabel('Weight on Naive IV', 'FontSize', label_size)

saveas(plot_info_joint, graph_path+ "Fig_5b.png")
saveas(plot_info_joint, graph_path+ "Fig_5b.eps", 'epsc');

%%  Figure 5a
% Bar chart comparing the naive IV (last weight-grid point, weight 1 on the
% naive IV) with the preferred IV (first grid point, weight 0).

label_size = 20;
legend_size = 24;
marker_size = 10;
axes_size = 16;

%Rows: rejection rate without estimation, with estimation, informativeness.
%Columns: naive IV (end of grid), preferred IV (start of grid).
y = [ [rej0_series(end); rej0_hat_series(end); info0_series(end)], ...
      [rej0_series(1);   rej0_hat_series(1);   info0_series(1)  ] ];

figure('DefaultAxesFontSize', axes_size, 'Position', [10 10 900 600]);
plot_bar_joint = tiledlayout(1,1);

b = bar(y);
set(b(1), 'FaceColor', [0.5 0.5 0.5]);
set(b(2), 'FaceColor', 'black');

yticks([0 .2 .4 .6 .8 1])
xticklabels({' Rejection rate \newline without estimation', ' Rejection rate \newline with estimation', 'Informativeness'})
legend('Naive IV', 'Preferred IV', 'Location', 'north', 'FontSize', legend_size)
%xlabel('Rejection rate      ', 'HorizontalAlignment','right')

ax = gca;
set(ax.XAxis, 'FontSize', 24);

saveas(plot_bar_joint, graph_path+ "Fig_5a.png")
saveas(plot_bar_joint, graph_path+ "Fig_5a.eps", 'epsc')